import numpy
import matplotlib.pyplot as pyplot
import scipy.constants as constants
import scipy.optimize as optimise
import os
from jqc import jqc_plot
from scipy.special import erf as erf
from matplotlib.patches import ConnectionPatch

# Absolute directory of this script; data files and saved figures are
# located relative to it.
cwd = os.path.dirname(os.path.abspath(__file__))

# Apply the shared JQC matplotlib style before any figure is created.
jqc_plot.plot_style('normal')

# Colour pairs: "point_*" colours are used for measured data points and
# "line_*" colours for the fitted curves.
point_blue, line_blue = jqc_plot.colours['grayblue'], jqc_plot.colours['blue']
point_red, line_red = jqc_plot.colours['reddish'], jqc_plot.colours['red']

#scipy's erf foes from -1 to 1, we want to go from 0 to 1 so I rescale and allow
#the user to pass a mean and standard deviation at the same time.
erf_scaled  = lambda x,x0,w: 0.5*(1+erf((x-x0)/(numpy.sqrt(2)*w)))

# Mathematical and physical constants (SI values from scipy.constants).
pi = numpy.pi
c, e0 = constants.c, constants.epsilon_0

# Experimental parameters.
# NOTE(review): WattsPerVolt is not referenced in this script; presumably a
# power-meter calibration kept for reference -- confirm before removing.
WattsPerVolt  = 0.985
convBohr3     = 4.68645e-8   #(MHz/W/m^2)/a0^3
# Trap-beam intensity and its uncertainty, both in W/m^2.
Intensity, Intensity_err = 8.6e3, 0.07e3

def chi_squared(x,y,alpha,model,pars):
    '''Return the chi-squared statistic of ``model(x, *pars)`` against ``y``.

    Each residual ``y - model(x, *pars)`` is squared and divided by the
    squared uncertainty ``alpha**2``; the weighted terms are then summed.
    '''
    residual = y - model(x, *pars)
    return numpy.sum(residual**2 / alpha**2)

def correlation_matrix(cov):
    '''Convert a covariance matrix into a correlation matrix.

    Element (i, j) of the result is cov[i, j] / sqrt(cov[i, i] * cov[j, j]).
    The result is a new float array with the same shape as ``cov``.
    '''
    n_rows, n_cols = cov.shape[0], cov.shape[1]
    corr = numpy.zeros(cov.shape)
    for i, j in numpy.ndindex(n_rows, n_cols):
        corr[i, j] = cov[i, j] / numpy.sqrt(cov[i, i] * cov[j, j])
    return corr

def Chauvenet(x,mean,std,N):
    '''Apply Chauvenet's criterion (Hughes & Hase) to a single value.

    The sample is assumed to follow a Gaussian with the given mean and
    standard deviation. The point is rejected if fewer than half a point this
    far from the mean (or further, on either side) would be expected among
    ``N`` samples.

    Parameters
    ----------
    x : float
        The value suspected of being an outlier.
    mean, std : float
        Mean and standard deviation of the assumed Gaussian distribution.
    N : int
        Number of points in the sample.

    Returns
    -------
    bool
        True if the point passes (keep it), False if it fails (an outlier).
    '''
    # Distance from the mean. The criterion is two-sided, so only the
    # magnitude matters. A signed deviation made P_out exceed 1 for points
    # below the mean, so low outliers could never be rejected.
    deviation = abs(x - mean)

    # Probability of a point lying at least this far from the mean:
    # 1 - P(|X - mean| < deviation) = 1 - erf(deviation / (sqrt(2) * std)).
    P_out = 1 - erf(deviation/(numpy.sqrt(2)*std))

    # Expected number of points this far (or further) out in a sample of N.
    N_outs = N*P_out

    # The test passes when at least half such a point is expected.
    return bool(N_outs >= 0.5)

def Gaussian(x,x0,w,A):
    '''Return a Gaussian profile A * exp(-(x - x0)**2 / (2 * w**2)) for fitting.

    Note this is *not* area-normalised. ``A`` is the peak height, i.e. the
    value at ``x == x0``; the area under the curve is ``A * w * sqrt(2*pi)``.

    Parameters
    ----------
    x : float or array
        Points at which to evaluate the profile.
    x0 : float
        Centre of the peak.
    w : float
        Standard deviation (width) of the peak.
    A : float
        Peak height.
    '''
    return A*numpy.exp(-(x-x0)**2/(2*w**2))

def Gaussian_derivative(var,x,x0,w,A):
    '''Return a partial derivative of ``Gaussian(x, x0, w, A)``.

    Parameters
    ----------
    var : str
        Parameter to differentiate with respect to (case-insensitive):
        ``'mean'``/``'centre'`` for x0, ``'width'`` for w, ``'area'`` for A.
        ``'area'`` is kept for backward compatibility; A is the peak height.
    x, x0, w, A :
        Arguments as for ``Gaussian``.

    Raises
    ------
    ValueError
        If ``var`` is not one of the recognised names.
    '''
    var = var.lower()
    # exp(-(x - x0)^2 / (2 w^2)) is common to every partial derivative.
    # (The previous version referred to an undefined name ``w0`` here.)
    profile = numpy.exp(-(x - x0)**2/(2.*w**2))
    if var == "mean" or var == 'centre':
        return A*(x - x0)/w**2*profile
    elif var == "width":
        return A*(x - x0)**2/w**3*profile
    elif var == "area":
        return profile
    raise ValueError("unknown Gaussian parameter: %r" % var)

def _mean_std(y):
    '''Return (mean, std) of ``y``, with 1e-10 added to the std.

    The offset keeps later divisions by the standard deviation from failing
    when every value in a group is identical. Such a group therefore reports
    a standard deviation of exactly 1e-10.
    '''
    return numpy.mean(y), numpy.std(y) + 1e-10

def _chauvenet_filter(y):
    '''Drop points failing Chauvenet's criterion until every point passes.

    The mean and std are recomputed after each pass. Returns the surviving
    values together with their (mean, std) from ``_mean_std``.
    '''
    mean, std = _mean_std(y)
    while True:
        keep = numpy.array([Chauvenet(value, mean, std, len(y)) for value in y],
                           dtype=bool)
        # Stop once nothing is rejected. Also stop, defensively, if every
        # point would be rejected rather than emptying the group.
        if keep.all() or not keep.any():
            break
        y = y[keep]
        mean, std = _mean_std(y)
    return y, mean, std

def average_data(dataset,use_chavenet = False):
    ''' Group rows of ``dataset`` by x and return per-group statistics.

    ``dataset`` is a 2-D array whose column 0 is the variable x and column 1
    the measured y. Rows sharing the same x value are averaged, optionally
    after removing outliers with Chauvenet's criterion. Data is returned as a
    (number of distinct x, 3) array with rows sorted by x:
    x   mean(y)     std(y)
    where std(y) includes a 1e-10 offset (see ``_mean_std``).
    '''
    #sort input data into ascending order on the variable column
    dataset = dataset[numpy.argsort(dataset[:,0]),:]
    x = dataset[:,0]

    #index of the first row of each distinct x value; each group runs up to
    #the start of the next group, and the last group runs to the end
    unique_x, starts = numpy.unique(x, return_index=True)
    stops = numpy.append(starts[1:], len(x))

    rows = []
    for x_value, start, stop in zip(unique_x, starts, stops):
        y = dataset[start:stop, 1]
        if use_chavenet:
            y, mean, std = _chauvenet_filter(y)
        else:
            mean, std = _mean_std(y)
        rows.append((x_value, mean, std))

    #reshape keeps a (0, 3) result when dataset has no rows
    return numpy.array(rows, dtype=float).reshape(-1, 3)
########### FIRST DATA SET ###################
#### FITTING ####
# Load the trap-off (P = 0 W) spectrum. os.path.join uses the platform's path
# separator; the previous hard-coded backslashes only worked on Windows.
data_free = numpy.genfromtxt(
    os.path.join(cwd, "Data", "Stokes", "P=0W_WL=1064.519(2).csv"),
    delimiter=',')

# Average repeated measurements at each x (used for plotting and chi^2).
avg_free = average_data(data_free,False)

# Fit a Gaussian to the raw (unaveraged) points; p0 = [centre, width, height].
test = [457,.15,2.5e3]
par_free,cov_free = optimise.curve_fit(Gaussian,data_free[:,0],data_free[:,1],
                    p0=test)

# 1-sigma parameter uncertainties from the diagonal of the covariance matrix.
err_free = numpy.sqrt(numpy.diag(cov_free))

# Repeat for the trap-on (P = 5.2 W) spectrum.
data_beam = numpy.genfromtxt(
    os.path.join(cwd, "Data", "Stokes", "P=5.2W_WL=1064.519(2).csv"),
    delimiter=',')

avg_beam = average_data(data_beam,False)

test = [457.16,.15,2.5e3]

par_beam,cov_beam = optimise.curve_fit(Gaussian,data_beam[:,0],data_beam[:,1],
                    p0=test)

err_beam = numpy.sqrt(numpy.diag(cov_beam))

# Report the fits. "x^2" is the reduced chi-squared against the averaged
# points: chi^2 / (number of points - number of fit parameters). Points whose
# std is no more than the 1e-10 offset from average_data (no measured spread)
# have no usable error bar and are excluded to avoid dividing by ~0.
print("Free-Space:",par_free,err_free)
print("pct error in centre",100*err_free[0]/par_free[1])
locs = numpy.where(avg_free[:,2]>1e-10)[0]
print("x^2=",chi_squared(avg_free[locs,0],avg_free[locs,1],avg_free[locs,2],
                         Gaussian,par_free)/(len(locs)-len(par_free)))

print("Beam:",par_beam,err_beam)
print("pct error in centre",100*err_beam[0]/par_beam[1])
locs = numpy.where(avg_beam[:,2]>1e-10)[0]
print("x^2=",chi_squared(avg_beam[locs,0],avg_beam[locs,1],avg_beam[locs,2],
                         Gaussian,par_beam)/(len(locs)-len(par_beam)))

###CALCULATE SHIFT FOR POLARISABILITY CALCULATIONS ####
# Shift of the fitted centre between trap off and trap on. The two centre
# uncertainties are combined in quadrature.
shift     = par_free[0]-par_beam[0]
shift_err = numpy.sqrt(err_free[0]**2+err_beam[0]**2)

print("Shift in MHz:",shift,shift_err)

#print("pct shift",100*shift/numpy.max([par_beam[1],par_free[1]]))
print("pct error in shift",100*shift_err/numpy.max([par_beam[1],par_free[1]]))

# Shift per unit intensity. Relative errors are combined in quadrature; abs()
# keeps the uncertainty positive when the shift is negative.
polarisability     = shift/Intensity
polarisability_err = abs(polarisability)*numpy.sqrt((shift_err/shift)**2+
                                                (Intensity_err/Intensity)**2)

print("Gradient:",polarisability*1e6,1e6*polarisability_err)

# Convert to units of a0^3 using convBohr3.
print("polarisability (a0^3)",
        polarisability/convBohr3,polarisability_err/convBohr3)

##### PLOTTING #####
#initialise the figure: two vertically stacked panels sharing the x axis;
#ax1 (top) shows this first data set, ax2 (bottom) is filled in later
fig = pyplot.figure("Spectroscopy")
ax1  = fig.add_subplot(211)
ax2  = fig.add_subplot(212,sharex=ax1)


#plot averaged data: x is measured relative to the fitted trap-off centre
#par_free[0] for BOTH spectra, and y is scaled by 1e-3 (thousands, to match
#the y-axis label)
ax1.errorbar(avg_free[:,0]-par_free[0],1e-3*avg_free[:,1],yerr=1e-3*avg_free[:,2],
            color=point_blue,zorder=1.6,fmt='o',capsize=3.5)

ax1.errorbar(avg_beam[:,0]-par_free[0],1e-3*avg_beam[:,1],yerr=1e-3*avg_beam[:,2],
            color=point_red,fmt='o',zorder=1.7,capsize=3.5)

#plot fitted Gaussians on a fine grid, shifted by the same trap-off centre
freq = numpy.linspace(456,458,1500)

ax1.plot(freq-par_free[0],1e-3*Gaussian(freq,*par_free),color=line_blue,
        label='trap off',zorder=1.2)

ax1.plot(freq-par_free[0],1e-3*Gaussian(freq,*par_beam),color=line_red,
        label = "trap on",zorder=1.3)


#ax1.set_xlabel("Molecule Number")
ax1.set_ylabel("$N_\\mathrm{mol}\\;(\\times 10^3)$")
ax1.set_ylim(-0.050,3.350)
#read the limits back after setting them; ymin/ymax are reused for the
#connecting lines and to give the lower panel the same y range
xmin,xmax = ax1.get_xlim()
ymin,ymax = ax1.get_ylim()
#ax1.invert_xaxis()

#ax1.text(0.05,0.9,"(a)",fontsize=20,transform=ax1.transAxes)
#panel label and annotation, placed in axes-fraction coordinates
ax1.text(0.92,0.85,"(a)",fontsize=20,transform=ax1.transAxes)
ax1.text(0.05,0.7,"$f_0-12.3$ GHz",fontsize=15,transform=ax1.transAxes)

#dashed red vertical line at this data set's trap-on centre (relative to the
#trap-off centre), drawn from the bottom of ax2 up to near the top of ax1 so
#it spans both panels
l = ConnectionPatch(xyA=(par_beam[0]-par_free[0],ymin),
                    xyB=(par_beam[0]-par_free[0],ymax-0.05),
                    coordsA='data',coordsB='data',
                    axesA=ax2,axesB=ax1,
                    color=jqc_plot.colours['red'],ls='--',lw=2,zorder=1.4)
ax2.add_artist(l)
########### Second DATA SET ###################
#### FITTING ####
# Load the trap-off (P = 0 W) spectrum. os.path.join uses the platform's path
# separator; the previous hard-coded backslashes only worked on Windows.
data_free = numpy.genfromtxt(
    os.path.join(cwd, "Data", "Stokes", "P=0W_WL=1064.4411.csv"),
    delimiter=',')

# Average repeated measurements at each x (used for plotting and chi^2).
avg_free = average_data(data_free,False)

# Fit a Gaussian to the raw (unaveraged) points; p0 = [centre, width, height].
test = [457,.15,2.5e3]
par_free,cov_free = optimise.curve_fit(Gaussian,data_free[:,0],data_free[:,1],
                    p0=test)

# 1-sigma parameter uncertainties from the diagonal of the covariance matrix.
err_free = numpy.sqrt(numpy.diag(cov_free))

# Repeat for the trap-on (P = 5.2 W) spectrum.
data_beam = numpy.genfromtxt(
    os.path.join(cwd, "Data", "Stokes", "P=5.2W_WL=1064.4411.csv"),
    delimiter=',')

avg_beam = average_data(data_beam,False)

test = [457.16,.15,2.5e3]

par_beam,cov_beam = optimise.curve_fit(Gaussian,data_beam[:,0],data_beam[:,1],
                    p0=test)

err_beam = numpy.sqrt(numpy.diag(cov_beam))

# Report the fits. "x^2" is the reduced chi-squared against the averaged
# points: chi^2 / (number of points - number of fit parameters). Points whose
# std is no more than the 1e-10 offset from average_data (no measured spread)
# have no usable error bar and are excluded to avoid dividing by ~0.
print("Free-Space:",par_free,err_free)
print("pct error in centre",100*err_free[0]/par_free[1])
locs = numpy.where(avg_free[:,2]>1e-10)[0]
print("x^2=",chi_squared(avg_free[locs,0],avg_free[locs,1],avg_free[locs,2],
                         Gaussian,par_free)/(len(locs)-len(par_free)))

print("Beam:",par_beam,err_beam)
#print(100*err_beam[0]/par_beam[1])
locs = numpy.where(avg_beam[:,2]>1e-10)[0]
print("x^2=",chi_squared(avg_beam[locs,0],avg_beam[locs,1],avg_beam[locs,2],
                         Gaussian,par_beam)/(len(locs)-len(par_beam)))

###CALCULATE SHIFT FOR POLARISABILITY CALCULATIONS ####
# Shift of the fitted centre between trap off and trap on. The two centre
# uncertainties are combined in quadrature.
shift     = par_free[0]-par_beam[0]
shift_err = numpy.sqrt(err_free[0]**2+err_beam[0]**2)

print("Shift in MHz:",shift,shift_err)

print("pct shift",100*shift/numpy.max([par_beam[1],par_free[1]]))
print("pct error in shift",100*shift_err/numpy.max([par_beam[1],par_free[1]]))

# Shift per unit intensity. Relative errors are combined in quadrature; abs()
# keeps the uncertainty positive when the shift is negative.
polarisability     = shift/Intensity
polarisability_err = abs(polarisability)*numpy.sqrt((shift_err/shift)**2+
                                                (Intensity_err/Intensity)**2)


print("Gradient:",1e6*polarisability,1e6*polarisability_err)
# Convert to units of a0^3 using convBohr3.
print("polarisability (a0^3)",
        polarisability/convBohr3,polarisability_err/convBohr3)

##### PLOTTING #####
# Lower panel (b): the second data set, drawn on ax2, which was created
# together with ax1 and shares its x axis.

#plot averaged data: x relative to this data set's fitted trap-off centre,
#y scaled by 1e-3 to match the y-axis label
ax2.errorbar(avg_free[:,0]-par_free[0],1e-3*avg_free[:,1],yerr=1e-3*avg_free[:,2],
            color=point_blue,fmt='o',zorder=1.6,capsize=3.5)

ax2.errorbar(avg_beam[:,0]-par_free[0],1e-3*avg_beam[:,1],yerr=1e-3*avg_beam[:,2],
            color=point_red,fmt='o',zorder=1.7,capsize=3.5)

#plot fitted Gaussians on a fine grid
freq = numpy.linspace(456,458,1500)

ax2.plot(freq-par_free[0],1e-3*Gaussian(freq,*par_free),color=line_blue,
        label='trap off',zorder=1.2)

ax2.plot(freq-par_free[0],1e-3*Gaussian(freq,*par_beam),color=line_red,
        label = "trap on",zorder=1.3)


#panel label and annotation, in axes-fraction coordinates
ax2.text(0.92,0.85,"(b)",fontsize=20,transform=ax2.transAxes)
ax2.text(0.05,0.7,"$f_0 +8.4$ GHz",fontsize=15,transform=ax2.transAxes)

#reuse the upper panel's y range (ymin, ymax) so both panels share a scale
ax2.set_xlabel("STIRAP Two-photon Detuning (MHz)")
ax2.set_ylim(ymin,ymax)
ax2.set_xlim(-0.6,0.5)

ax2.set_ylabel("$N_\\mathrm{mol}\\;(\\times 10^3)$")


#dashed red line at this data set's trap-on centre, spanning both panels
l = ConnectionPatch(xyA=(par_beam[0]-par_free[0],ymin),
                    xyB=(par_beam[0]-par_free[0],ymax-0.05),
                    coordsA='data',coordsB='data',
                    axesA=ax2,axesB=ax1,
                    color=jqc_plot.colours['red'],ls='--',lw=2,zorder=1.)
ax2.add_artist(l)

#dotted blue line at zero detuning (the trap-off centre), spanning both panels
l = ConnectionPatch(xyA=(0,ymin),
                    xyB=(0,ymax-0.05),
                    coordsA='data',coordsB='data',
                    axesA=ax2,axesB=ax1,
                    color=jqc_plot.colours['blue'],ls=':',lw=2,zorder=1.)
ax2.add_artist(l)


#hide the upper panel's x tick labels; the panels share the x axis
pyplot.setp(ax1.get_xticklabels(),visible=False)


pyplot.tight_layout()
pyplot.subplots_adjust(hspace=0)
#build output paths portably; hard-coded backslashes only worked on Windows
pyplot.savefig(os.path.join(cwd, "Fit_Stokes_spec_2.pdf"))
pyplot.savefig(os.path.join(cwd, "Fit_Stokes_spec_2.png"))
pyplot.show()
